﻿/*
 * *******************************************************
 *
 * 作者：hzy
 *
 * 开源地址：https://gitee.com/hzy6
 *
 * *******************************************************
 */


/*
 * *******************************************************
 *
 * 作者：hzy
 *
 * 开源地址：https://gitee.com/hzy6
 *
 * *******************************************************
 */

using DbContext = Microsoft.EntityFrameworkCore.DbContext;

namespace HZY.Framework.Repository.EntityFramework.Repositories.Impl;

/// <summary>
/// 工作单元
/// </summary>
public class UnitOfWorkImpl<TDbContext> : IUnitOfWork where TDbContext : DbContext
{
    /// <summary>
    /// Delayed-save switch. When <c>true</c>, the <c>SaveChanges*</c> methods of this unit of work
    /// forward to the context immediately; when <c>false</c>, they are no-ops returning 0 until
    /// the delayed commit is ended.
    /// </summary>
    private bool _saveState = true;

    /// <summary>
    /// Set once <see cref="Dispose"/> or <see cref="DisposeAsync"/> has run, so the wrapped
    /// context is disposed at most once.
    /// </summary>
    private bool _disposed;

    private readonly TDbContext _dbContext;

    /// <summary>
    /// FreeSql instance registered for <typeparamref name="TDbContext"/> in
    /// <c>RepositoryEntityFrameworkExtensions.FreeSqlOrmList</c>.
    /// NOTE: stays <c>null</c> when no instance is registered for this context type.
    /// </summary>
    protected readonly IFreeSql freeSqlOrm;

    /// <summary>
    /// Creates a unit of work over the given context and picks up the FreeSql instance
    /// registered for <typeparamref name="TDbContext"/>, if any.
    /// </summary>
    /// <param name="dbContext">The EF Core context this unit of work operates on.</param>
    /// <exception cref="ArgumentNullException"><paramref name="dbContext"/> is <c>null</c>.</exception>
    public UnitOfWorkImpl(TDbContext dbContext)
    {
        ArgumentNullException.ThrowIfNull(dbContext);
        _dbContext = dbContext;

        if (RepositoryEntityFrameworkExtensions.FreeSqlOrmList.TryGetValue(typeof(TDbContext), out IFreeSql? value))
        {
            freeSqlOrm = value;
        }
    }

    /// <summary>
    /// Gets the delayed-save state.
    /// </summary>
    /// <returns><c>true</c> if saves are forwarded immediately; <c>false</c> while saves are being deferred.</returns>
    public virtual bool GetDelaySaveState() => _saveState;

    /// <summary>
    /// Sets the delayed-save state.
    /// </summary>
    /// <param name="saveSate"><c>true</c> to forward saves immediately; <c>false</c> to defer them.</param>
    public virtual void SetDelaySaveState(bool saveSate) => _saveState = saveSate;

    /// <summary>
    /// Starts delayed commit: subsequent <c>SaveChanges*</c> calls on this unit of work return 0
    /// without touching the database until <see cref="CommitDelayEnd"/> or
    /// <see cref="CommitDelayEndAsync"/> is called.
    /// </summary>
    public virtual void CommitDelayStart() => _saveState = false;

    /// <summary>
    /// Ends delayed commit and saves all pending changes.
    /// On a <see cref="DbUpdateConcurrencyException"/>, each conflicting entry's original values are
    /// replaced by the current database values (so the locally tracked values are written on the
    /// next attempt) and the save is retried.
    /// </summary>
    /// <param name="retryCount">
    /// Maximum number of retries after a concurrency conflict (total attempts = retryCount + 1).
    /// </param>
    /// <returns>The value returned by <see cref="SaveChanges()"/>.</returns>
    /// <exception cref="DbUpdateConcurrencyException">The conflict persists after <paramref name="retryCount"/> retries.</exception>
    public virtual int CommitDelayEnd(int retryCount = 10)
    {
        // Re-enable saving first, otherwise SaveChanges() below would be a no-op.
        SetDelaySaveState(true);

        for (var retry = 0; ; retry++)
        {
            try
            {
                return SaveChanges();
            }
            catch (DbUpdateConcurrencyException ex) when (retry < retryCount)
            {
                foreach (var entry in ex.Entries)
                {
                    // Refresh original values from the database so the next concurrency check passes.
                    // NOTE(review): a null result means the row no longer exists; such an entry is left
                    // unchanged and will most likely conflict again on the next attempt.
                    var databaseValues = entry.GetDatabaseValues();
                    if (databaseValues is not null)
                    {
                        entry.OriginalValues.SetValues(databaseValues);
                    }
                }
            }
        }
    }

    /// <summary>
    /// Asynchronously ends delayed commit and saves all pending changes.
    /// On a <see cref="DbUpdateConcurrencyException"/>, each conflicting entry's original values are
    /// replaced by the current database values (so the locally tracked values are written on the
    /// next attempt) and the save is retried.
    /// </summary>
    /// <param name="retryCount">
    /// Maximum number of retries after a concurrency conflict (total attempts = retryCount + 1).
    /// </param>
    /// <returns>The value returned by <see cref="SaveChangesAsync()"/>.</returns>
    /// <exception cref="DbUpdateConcurrencyException">The conflict persists after <paramref name="retryCount"/> retries.</exception>
    public virtual async Task<int> CommitDelayEndAsync(int retryCount = 10)
    {
        // Re-enable saving first, otherwise SaveChangesAsync() below would be a no-op.
        SetDelaySaveState(true);

        for (var retry = 0; ; retry++)
        {
            try
            {
                return await SaveChangesAsync();
            }
            catch (DbUpdateConcurrencyException ex) when (retry < retryCount)
            {
                foreach (var entry in ex.Entries)
                {
                    // Refresh original values from the database (asynchronously, to keep this path
                    // free of blocking I/O) so the next concurrency check passes.
                    // NOTE(review): a null result means the row no longer exists; such an entry is left
                    // unchanged and will most likely conflict again on the next attempt.
                    var databaseValues = await entry.GetDatabaseValuesAsync();
                    if (databaseValues is not null)
                    {
                        entry.OriginalValues.SetValues(databaseValues);
                    }
                }
            }
        }
    }

    /// <summary>
    /// Begins a transaction on the wrapped context's database.
    /// </summary>
    /// <returns>The new context transaction.</returns>
    public virtual IDbContextTransaction BeginTransaction() => _dbContext.Database.BeginTransaction();

    /// <summary>
    /// Asynchronously begins a transaction on the wrapped context's database.
    /// </summary>
    /// <returns>The new context transaction.</returns>
    public virtual Task<IDbContextTransaction> BeginTransactionAsync() => _dbContext.Database.BeginTransactionAsync();

    /// <summary>
    /// The wrapped context's current transaction, or <c>null</c> if none is active.
    /// </summary>
    public virtual IDbContextTransaction? CurrentDbContextTransaction => _dbContext.Database.CurrentTransaction;

    /// <summary>
    /// The ADO.NET transaction underlying the context's current transaction, or <c>null</c> if none is active.
    /// </summary>
    public virtual IDbTransaction? CurrentDbTransaction =>
        _dbContext.Database.CurrentTransaction is { } transaction
            ? GetDbTransaction(transaction)
            : null;

    /// <summary>
    /// FreeSql instance registered for <typeparamref name="TDbContext"/>.
    /// NOTE: <c>null</c> when no instance was registered for this context type.
    /// </summary>
    public IFreeSql FreeSqlOrm => freeSqlOrm;

    /// <summary>
    /// Gets the ADO.NET transaction that backs the given EF Core context transaction.
    /// </summary>
    /// <param name="dbContextTransaction">The EF Core transaction.</param>
    /// <returns>The underlying <see cref="IDbTransaction"/>.</returns>
    public virtual IDbTransaction GetDbTransaction(IDbContextTransaction dbContextTransaction)
    {
        return dbContextTransaction.GetDbTransaction();
    }

    /// <summary>
    /// Gets the context's <c>DbSet</c> for entity type <typeparamref name="T"/>.
    /// </summary>
    /// <typeparam name="T">Entity type.</typeparam>
    /// <returns>The set returned by <c>DbContext.Set&lt;T&gt;()</c>.</returns>
    public virtual Microsoft.EntityFrameworkCore.DbSet<T> DbSet<T>() where T : class, new() => _dbContext.Set<T>();

    /// <summary>
    /// Saves pending changes, unless delayed commit is active (then returns 0 without saving).
    /// </summary>
    /// <returns>Number of state entries written, or 0 when saving is deferred.</returns>
    public virtual int SaveChanges()
    {
        return GetDelaySaveState() ? _dbContext.SaveChanges() : 0;
    }

    /// <summary>
    /// Saves pending changes, unless delayed commit is active (then returns 0 without saving).
    /// </summary>
    /// <param name="acceptAllChangesOnSuccess">Forwarded to <c>DbContext.SaveChanges(bool)</c>.</param>
    /// <returns>Number of state entries written, or 0 when saving is deferred.</returns>
    public virtual int SaveChanges(bool acceptAllChangesOnSuccess)
    {
        return GetDelaySaveState() ? _dbContext.SaveChanges(acceptAllChangesOnSuccess) : 0;
    }

    /// <summary>
    /// Asynchronously saves pending changes, unless delayed commit is active (then completes with 0 without saving).
    /// </summary>
    /// <returns>Number of state entries written, or 0 when saving is deferred.</returns>
    public virtual Task<int> SaveChangesAsync()
    {
        return GetDelaySaveState() ? _dbContext.SaveChangesAsync() : Task.FromResult(0);
    }

    /// <summary>
    /// Asynchronously saves pending changes, unless delayed commit is active (then completes with 0 without saving).
    /// </summary>
    /// <param name="cancellationToken">Token forwarded to the context's save.</param>
    /// <returns>Number of state entries written, or 0 when saving is deferred.</returns>
    public virtual Task<int> SaveChangesAsync(CancellationToken cancellationToken = default)
    {
        return GetDelaySaveState() ? _dbContext.SaveChangesAsync(cancellationToken) : Task.FromResult(0);
    }

    /// <summary>
    /// Asynchronously saves pending changes, unless delayed commit is active (then completes with 0 without saving).
    /// </summary>
    /// <param name="acceptAllChangesOnSuccess">Forwarded to <c>DbContext.SaveChangesAsync(bool, CancellationToken)</c>.</param>
    /// <param name="cancellationToken">Token forwarded to the context's save.</param>
    /// <returns>Number of state entries written, or 0 when saving is deferred.</returns>
    public virtual Task<int> SaveChangesAsync(bool acceptAllChangesOnSuccess,
        CancellationToken cancellationToken = default)
    {
        return GetDelaySaveState()
            ? _dbContext.SaveChangesAsync(acceptAllChangesOnSuccess, cancellationToken)
            : Task.FromResult(0);
    }

    /// <summary>
    /// Asynchronously releases the resources owned by this unit of work (see <see cref="DisposeAllAsync"/>).
    /// Safe to call more than once; only the first call disposes.
    /// </summary>
    public async ValueTask DisposeAsync()
    {
        if (!_disposed)
        {
            _disposed = true;
            await DisposeAllAsync();
        }

        // No forced GC here: a full blocking collection per unit of work is costly and unnecessary.
        GC.SuppressFinalize(this);
    }

    /// <summary>
    /// Releases the resources owned by this unit of work (see <see cref="DisposeAll"/>).
    /// Safe to call more than once; only the first call disposes.
    /// </summary>
    public virtual void Dispose()
    {
        if (!_disposed)
        {
            _disposed = true;
            DisposeAll();
        }

        // No forced GC here: a full blocking collection per unit of work is costly and unnecessary.
        GC.SuppressFinalize(this);
    }

    /// <summary>
    /// Disposal hook: disposes the wrapped context. Derived classes may override to release
    /// additional resources (call the base implementation).
    /// </summary>
    /// <remarks>
    /// This class deliberately has no finalizer. The context is a managed object that may still be
    /// owned and used by its DI scope, so it must never be disposed from the finalizer thread.
    /// </remarks>
    protected virtual void DisposeAll()
    {
        _dbContext.Dispose();
    }

    /// <summary>
    /// Asynchronous disposal hook: disposes the wrapped context asynchronously. Derived classes may
    /// override to release additional resources (await the base implementation).
    /// </summary>
    protected virtual async Task DisposeAllAsync()
    {
        await _dbContext.DisposeAsync();
    }


    #region Raw SQL operations

    /// <summary>
    /// Builds an EF Core query for <typeparamref name="T"/> from raw SQL (<c>FromSqlRaw</c>).
    /// SECURITY: pass user input through <paramref name="parameters"/>, never by concatenating it into <paramref name="sql"/>.
    /// </summary>
    /// <typeparam name="T">Entity type of the set the query is rooted on.</typeparam>
    /// <param name="sql">Raw SQL text.</param>
    /// <param name="parameters">SQL parameter values.</param>
    /// <returns>A composable <see cref="IQueryable{T}"/>.</returns>
    public virtual IQueryable<T> QueryableBySql<T>(string sql, params object[] parameters) where T : class, new()
    {
        return this.DbSet<T>().FromSqlRaw(sql, parameters);
    }

    /// <summary>
    /// Runs a SQL query and returns the result as a <see cref="DataTable"/>
    /// (delegates to the <c>QueryDataTableBySql</c> database extension).
    /// </summary>
    /// <param name="sql">SQL text.</param>
    /// <param name="parameters">SQL parameter values.</param>
    /// <returns>The result table, or <c>null</c> as returned by the extension.</returns>
    public virtual DataTable? QueryDataTableBySql(string sql, params object[] parameters)
    {
        return _dbContext.Database.QueryDataTableBySql(sql, parameters);
    }

    /// <summary>
    /// Asynchronously runs a SQL query and returns the result as a <see cref="DataTable"/>
    /// (delegates to the <c>QueryDataTableBySqlAsync</c> database extension).
    /// </summary>
    /// <param name="sql">SQL text.</param>
    /// <param name="parameters">SQL parameter values.</param>
    /// <returns>The result table, or <c>null</c> as returned by the extension.</returns>
    public virtual Task<DataTable?> QueryDataTableBySqlAsync(string sql, params object[] parameters)
    {
        return _dbContext.Database.QueryDataTableBySqlAsync(sql, parameters);
    }

    /// <summary>
    /// Runs a SQL query and returns each row as a dictionary
    /// (delegates to the <c>QueryDicBySql</c> database extension).
    /// </summary>
    /// <param name="sql">SQL text.</param>
    /// <param name="parameters">SQL parameter values.</param>
    /// <returns>The rows, or <c>null</c> as returned by the extension.</returns>
    public virtual List<Dictionary<string, object?>>? QueryDicBySql(string sql, params object[] parameters)
    {
        return _dbContext.Database.QueryDicBySql(sql, parameters);
    }

    /// <summary>
    /// Asynchronously runs a SQL query and returns each row as a dictionary
    /// (delegates to the <c>QueryDicBySqlAsync</c> database extension).
    /// </summary>
    /// <param name="sql">SQL text.</param>
    /// <param name="parameters">SQL parameter values.</param>
    /// <returns>The rows, or <c>null</c> as returned by the extension.</returns>
    public virtual Task<List<Dictionary<string, object?>>?> QueryDicBySqlAsync(string sql, params object[] parameters)
    {
        return _dbContext.Database.QueryDicBySqlAsync(sql, parameters);
    }

    /// <summary>
    /// Runs a SQL query and maps the rows to <typeparamref name="T"/>
    /// (delegates to the <c>QueryBySql</c> database extension).
    /// </summary>
    /// <param name="sql">SQL text.</param>
    /// <param name="parameters">SQL parameter values.</param>
    /// <returns>The mapped rows, or <c>null</c> as returned by the extension.</returns>
    public virtual List<T>? QueryBySql<T>(string sql, params object[] parameters) where T : class, new()
    {
        return _dbContext.Database.QueryBySql<T>(sql, parameters);
    }

    /// <summary>
    /// Asynchronously runs a SQL query and maps the rows to <typeparamref name="T"/>
    /// (delegates to the <c>QueryBySqlAsync</c> database extension).
    /// </summary>
    /// <param name="sql">SQL text.</param>
    /// <param name="parameters">SQL parameter values.</param>
    /// <returns>The mapped rows, or <c>null</c> as returned by the extension.</returns>
    public virtual Task<List<T>?> QueryBySqlAsync<T>(string sql, params object[] parameters) where T : class, new()
    {
        return _dbContext.Database.QueryBySqlAsync<T>(sql, parameters);
    }

    /// <summary>
    /// Runs a SQL query and returns a single value
    /// (delegates to the <c>QuerySingleBySql</c> database extension).
    /// </summary>
    /// <param name="sql">SQL text.</param>
    /// <param name="parameters">SQL parameter values.</param>
    /// <returns>The value, or <c>null</c> as returned by the extension.</returns>
    public virtual object? QuerySingleBySql(string sql, params object[] parameters)
    {
        return _dbContext.Database.QuerySingleBySql(sql, parameters);
    }

    /// <summary>
    /// Asynchronously runs a SQL query and returns a single value
    /// (delegates to the <c>QuerySingleBySqlAsync</c> database extension).
    /// </summary>
    /// <param name="sql">SQL text.</param>
    /// <param name="parameters">SQL parameter values.</param>
    /// <returns>The value, or <c>null</c> as returned by the extension.</returns>
    public virtual Task<object?> QuerySingleBySqlAsync(string sql, params object[] parameters)
    {
        return _dbContext.Database.QuerySingleBySqlAsync(sql, parameters);
    }

    /// <summary>
    /// Runs a SQL query and returns a single value typed as <typeparamref name="TResult"/>
    /// (delegates to the <c>QuerySingleBySql&lt;TResult&gt;</c> database extension).
    /// </summary>
    /// <typeparam name="TResult">Value type of the result.</typeparam>
    /// <param name="sql">SQL text.</param>
    /// <param name="parameters">SQL parameter values.</param>
    /// <returns>The value returned by the extension.</returns>
    public virtual TResult QuerySingleBySql<TResult>(string sql, params object[] parameters)
        where TResult : struct
    {
        return _dbContext.Database.QuerySingleBySql<TResult>(sql, parameters);
    }

    /// <summary>
    /// Asynchronously runs a SQL query and returns a single value typed as <typeparamref name="TResult"/>
    /// (delegates to the <c>QuerySingleBySqlAsync&lt;TResult&gt;</c> database extension).
    /// </summary>
    /// <typeparam name="TResult">Value type of the result.</typeparam>
    /// <param name="sql">SQL text.</param>
    /// <param name="parameters">SQL parameter values.</param>
    /// <returns>The value returned by the extension.</returns>
    public virtual Task<TResult> QuerySingleBySqlAsync<TResult>(string sql, params object[] parameters)
        where TResult : struct
    {
        return _dbContext.Database.QuerySingleBySqlAsync<TResult>(sql, parameters);
    }

    #endregion

}